"""
Probe: GDP/capita confound in demographics → interest rates
Tests whether demographics predict rates independently of income levels.
"""
import sys
from pathlib import Path
import numpy as np
import pandas as pd

# Root of the project tree; every other path in this script is derived from it.
PROJECT = Path("/mnt/c/demographics_capital_flows")
# Put multilateral/src at the front of the import path *before* importing
# PanelGLS from `model` (presumably model.py lives in that directory — confirm).
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# Input panel read by main(), and the directory the markdown table is written to.
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
# NOTE: executed at import time, so merely importing this module creates the
# output directory (and any missing parents) if it does not exist yet.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# Three-letter country codes of the 38-country OECD sample; matched against
# the panel's `iso3` column to split it into OECD / Non-OECD subsamples.
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()

def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.1 and an
    empty string otherwise. All comparisons are strict, so a value exactly on
    a threshold gets the weaker marker.
    """
    # Thresholds are listed from strictest to loosest; the first match wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''

def run_model(df, dep_var, rhs_vars, label):
    """Fit a PanelGLS regression of ``dep_var`` on ``rhs_vars`` and print the results.

    Rows with a missing value in the dependent variable, any regressor, or
    the ``iso3`` / ``year`` identifiers are dropped before fitting.

    Args:
        df: DataFrame holding the dependent variable, the regressors and the
            ``iso3`` / ``year`` panel identifier columns.
        dep_var: Column name of the dependent variable.
        rhs_vars: Sequence of regressor column names, in the order they are
            passed to the estimator.
        label: Model name used in the console output and stored in the result.

    Returns:
        A dict with ``label``, ``n``, ``nc`` and ``r2`` plus, for every
        regressor ``v``, the keys ``{v}_coef``, ``{v}_se`` and ``{v}_p``.
        ``None`` when the model is skipped, i.e. when a required column is
        absent from ``df`` or fewer than 30 complete observations remain.
    """
    rhs_vars = list(rhs_vars)
    cols = [dep_var, *rhs_vars, "iso3", "year"]

    # A missing column is reported and skipped the same way as a too-small
    # sample, instead of aborting the whole probe with a bare KeyError.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  SKIP {label}: missing columns {missing}")
        return None

    sub = df[cols].dropna()
    if len(sub) < 30:
        print(f"  SKIP {label}: only {len(sub)} obs")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[rhs_vars].values, sub["iso3"].values, sub["year"].values)
    print(f"\n  {label} (N={gls.n_obs}, {gls.n_countries} countries, R\u00b2={gls.r_squared:.4f})")
    res = {"label": label, "n": gls.n_obs, "nc": gls.n_countries, "r2": gls.r_squared}

    # NOTE(review): assumes gls.beta / gls.se / gls.pvalues are ordered like
    # the columns of the regressor matrix (no leading intercept) — confirm
    # against PanelGLS.
    for i, v in enumerate(rhs_vars):
        coef, se, pval = gls.beta[i], gls.se[i], gls.pvalues[i]
        print(f"    {v:30s} {coef:10.4f} ({se:.4f}) {stars(pval)}  p={pval:.4f}")
        res[f"{v}_coef"] = coef
        res[f"{v}_se"] = se
        res[f"{v}_p"] = pval
    return res

def _find_gdp_per_capita(panel):
    """Return the GDP/capita column name in ``panel``, building it from rgdpna/pop if needed.

    Returns ``None`` when no known column exists and it cannot be constructed.
    When constructed, the new ``gdp_pc_ppp`` column is added to ``panel`` in place.
    """
    for name in ('gdp_pc_ppp', 'gdp_per_capita', 'gdp_pc', 'rgdp_pc'):
        if name in panel.columns:
            return name
    if 'rgdpna' in panel.columns and 'pop' in panel.columns:
        panel['gdp_pc_ppp'] = panel['rgdpna'] / panel['pop']
        print("Constructed gdp_pc_ppp from rgdpna/pop")
        return 'gdp_pc_ppp'
    return None


def _attenuation_pct(base, with_gdp):
    """Return the percent shrinkage of ``with_gdp`` relative to ``base``, or None if ``base`` is ~0."""
    if abs(base) > 0.001:
        return (1 - with_gdp / base) * 100
    return None


def _coef_cells(result, var):
    """Return the (coefficient-with-stars, p-value) table cells for ``var``; blank if not in the model."""
    coef = result.get(f"{var}_coef")
    if coef is None:
        return "", ""
    pval = result[f"{var}_p"]
    return f"{coef:.2f}{stars(pval)}", f"{pval:.4f}"


def main():
    """Run the GDP/capita confound probe and write a markdown summary.

    Reads the panel from ``DATA_PATH``, adds ``log_gdp_pc`` and its
    interaction with ``Z_1``, then, for the first two available rate
    variables and for the Full / OECD / Non-OECD samples, fits three models
    through ``run_model``:

    * M1: Z variables + controls,
    * M2: M1 + ``log_gdp_pc``,
    * M3: M2 + ``Z_1`` x ``log_gdp_pc``.

    Coefficients and the attenuation of the ``Z_1`` coefficient from M1 to M2
    are printed and written to ``TABLE_DIR / "gdp_capita_confound.md"``.
    Returns ``None``; prints an error and returns early when the GDP/capita
    variable or any required regressor column is unavailable.
    """
    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries")

    # Diagnostic only: list every column that looks like GDP per capita.
    gdp_candidates = [
        c for c in panel.columns
        if 'gdp' in c.lower() and any(tag in c.lower() for tag in ('pc', 'cap', 'per'))
    ]
    print(f"GDP/capita candidates: {gdp_candidates}")

    gdp_var = _find_gdp_per_capita(panel)
    if gdp_var is None:
        print("ERROR: No GDP/capita variable found")
        print("Available columns:", sorted(panel.columns.tolist()))
        return
    print(f"Using GDP/capita variable: {gdp_var}")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']
    z_vars = ['Z_1', 'Z_2', 'Z_3']

    # Fail with the same explicit message style as the missing-GDP case
    # rather than a KeyError deep inside the model loop.
    missing = [c for c in z_vars + controls + ['year'] if c not in panel.columns]
    if missing:
        print(f"ERROR: Required columns missing from panel: {missing}")
        print("Available columns:", sorted(panel.columns.tolist()))
        return

    # Floor GDP/capita at 100 before taking logs so zero, negative or tiny
    # values cannot produce -inf / NaN.
    panel['log_gdp_pc'] = np.log(panel[gdp_var].clip(lower=100))
    panel['Z1_x_log_gdp_pc'] = panel['Z_1'] * panel['log_gdp_pc']

    dep_vars = [
        dv for dv in ('govt_bond_yield_10y', 'tbill_rate_3m', 'bond_yield_10y', 'lending_rate')
        if dv in panel.columns
    ]
    print(f"Dependent variables found: {dep_vars}")

    # The panel is not modified past this point, so the sample masks can be
    # computed once. Rows without an iso3 fall into "Non-OECD" here but are
    # dropped by run_model's dropna.
    in_oecd = panel['iso3'].isin(OECD_38)
    samples = [
        ("Full", panel['iso3'].notna()),
        ("OECD", in_oecd),
        ("Non-OECD", ~in_oecd),
    ]

    all_results = []
    attenuation_rows = []  # (dep_var, sample, M1 result, M2 result)
    md_lines = ["# GDP/Capita Confound Probe: Demographics \u2192 Interest Rates\n"]
    md_lines.append("Does GDP/capita absorb the demographic signal on rates?\n")

    sep = '=' * 70
    for dep_var in dep_vars[:2]:  # First two rate variables
        print(f"\n{sep}")
        print(f"DEPENDENT VARIABLE: {dep_var}")
        print(f"{sep}")

        for sample_label, mask in samples:
            sub = panel[mask].copy()
            print(f"\n--- {sample_label} sample ({sub['iso3'].nunique()} countries) ---")

            # M1: Baseline (Z only)
            r1 = run_model(sub, dep_var, z_vars + controls, f"M1: Z baseline [{sample_label}]")

            # M2: Z + log_gdp_pc
            r2 = run_model(sub, dep_var, z_vars + ['log_gdp_pc'] + controls, f"M2: Z + GDP/pc [{sample_label}]")

            # M3: Horse race with interaction
            r3 = run_model(sub, dep_var, z_vars + ['log_gdp_pc', 'Z1_x_log_gdp_pc'] + controls, f"M3: Horse race [{sample_label}]")

            # Attenuation of the Z_1 coefficient once GDP/capita is controlled for
            if r1 is not None and r2 is not None:
                z1_base = r1.get('Z_1_coef', 0)
                z1_gdp = r2.get('Z_1_coef', 0)
                atten = _attenuation_pct(z1_base, z1_gdp)
                if atten is not None:
                    print(f"\n  >>> Z\u2081 attenuation when adding GDP/pc: {atten:.1f}%")
                    print(f"  >>> Z\u2081: {z1_base:.2f} \u2192 {z1_gdp:.2f}")
                attenuation_rows.append((dep_var, sample_label, r1, r2))

            for r in (r1, r2, r3):
                if r is not None:
                    r['dep_var'] = dep_var
                    r['sample'] = sample_label
                    all_results.append(r)

    # Build markdown summary
    md_lines.append("\n## Key Results\n")
    md_lines.append("| Dep Var | Sample | Model | Z\u2081 coef | Z\u2081 p | log_gdp_pc coef | log_gdp_pc p | Z\u2081\u00d7log_gdp_pc coef | Z\u2081\u00d7log_gdp_pc p | N | R\u00b2 |")
    md_lines.append("|:---|:---|:---|---:|---:|---:|---:|---:|---:|---:|---:|")

    for r in all_results:
        # Labels look like "M1: Z baseline [Full]"; the sample has its own column.
        model_name = r['label'].split('[')[0].strip()
        z1_str, z1p_str = _coef_cells(r, 'Z_1')
        gp_str, gpp_str = _coef_cells(r, 'log_gdp_pc')
        ip_str, ipp_str = _coef_cells(r, 'Z1_x_log_gdp_pc')
        md_lines.append(f"| {r['dep_var']} | {r['sample']} | {model_name} | {z1_str} | {z1p_str} | {gp_str} | {gpp_str} | {ip_str} | {ipp_str} | {r['n']} | {r['r2']:.3f} |")

    # Attenuation summary
    md_lines.append("\n## Attenuation Summary\n")
    md_lines.append("| Dep Var | Sample | Z\u2081 baseline | Z\u2081 + GDP/pc | Attenuation % |")
    md_lines.append("|:---|:---|---:|---:|---:|")

    for dep_var, sample, base_res, gdp_res in attenuation_rows:
        b = base_res.get('Z_1_coef', 0)
        g = gdp_res.get('Z_1_coef', 0)
        bp = base_res.get('Z_1_p', 1)
        gp2 = gdp_res.get('Z_1_p', 1)
        atten = _attenuation_pct(b, g)
        # A near-zero baseline makes the ratio meaningless; say so instead of
        # reporting a misleading "0.0%".
        atten_str = f"{atten:.1f}%" if atten is not None else "n/a"
        md_lines.append(f"| {dep_var} | {sample} | {b:.2f}{stars(bp)} | {g:.2f}{stars(gp2)} | {atten_str} |")

    md_lines.append("\n*Attenuation >50% suggests GDP/capita substantially confounds demographics.*")
    md_lines.append("*Significant Z\u2081\u00d7log_gdp_pc interaction means demographics operate differently across income levels.*")

    out_path = TABLE_DIR / "gdp_capita_confound.md"
    # The table contains non-ASCII characters, so the encoding must not depend
    # on the platform/locale default.
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(md_lines))
    print(f"\nSaved: {out_path}")

# Run the probe only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
